[JAX] Support for batched einsum and grouped GEMM without D2H memcpy #2604

Draft · jberchtold-nvidia wants to merge 67 commits into NVIDIA:main
Conversation
Commits:

- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM; fix random padding in tests to ensure 16-byte alignment for all dtypes; reorder GroupedGemmSetupWorkspace members for natural alignment; remove debug prints. (Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>)
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers; simplify select_grouped_operand by removing dead code branches; add GroupedOperandSelection.tensor field to avoid passing the tensor separately; extract set_fp8_scale_pointers and init_matrix_layouts helpers; add safety check for the FP8 column-wise fallback on Hopper; support NULL C tensor when beta=0 (uses D as placeholder); remove unused get_scale_inv() from test; add use_null_c test parameter and test case; fix documentation: alpha/beta are single-element tensors only. (Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>, Pawel Gadzinski <pgadzinski@nvidia.com>)
- Change alpha/beta from single values to per-matrix arrays; validate alpha/beta have exactly num_tensors elements; update kernel to index alpha_ptr[idx] and beta_ptr[idx]; move alpha/beta validation to validate_grouped_gemm_inputs; update tests to use per-matrix alpha/beta arrays; update documentation (see the sketch after this list). (Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>, Pawel Gadzinski <pgadzinski@nvidia.com>)
- Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>, Pawel Gadzinski <pgadzinski@nvidia.com>
- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- This reverts commit bc6cf66.
- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- … single-stream for multi tensor quantize)
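The per-matrix alpha/beta commit above changes the grouped-GEMM contract so that each group gets its own scaling pair. A minimal reference sketch of those semantics in JAX; `grouped_gemm_reference` is a hypothetical helper for illustration, not the actual nvte_grouped_gemm API (which operates on device buffers through cuBLAS):

```python
import jax.numpy as jnp

def grouped_gemm_reference(a_list, b_list, c_list, alpha, beta):
    """Reference semantics only: D[i] = alpha[i] * A[i] @ B[i] + beta[i] * C[i].

    alpha and beta must have exactly num_tensors elements, mirroring the
    validation described in the commit (validate_grouped_gemm_inputs).
    """
    assert len(alpha) == len(a_list) and len(beta) == len(a_list)
    return [alpha[i] * (a_list[i] @ b_list[i]) + beta[i] * c_list[i]
            for i in range(len(a_list))]

# Example: three groups with different M dims and per-group scaling.
a = [jnp.ones((4, 8)), jnp.ones((2, 8)), jnp.ones((6, 8))]
b = [jnp.ones((8, 16))] * 3
c = [jnp.zeros((4, 16)), jnp.zeros((2, 16)), jnp.zeros((6, 16))]
out = grouped_gemm_reference(a, b, c, alpha=[1.0, 0.5, 2.0], beta=[0.0, 0.0, 1.0])
```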
Greptile Summary

This PR adds support for batched einsum operations and grouped GEMM without device-to-host memory copies, enabling efficient Mixture-of-Experts (MoE) implementations with per-expert FP8 quantization in JAX.

Key Changes
Testing

Comprehensive test coverage includes:
Temporary Workarounds
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as JAX User Code
    participant Einsum as einsum()
    participant Dense as dense()
    participant GEMM as GemmPrimitive
    participant Quant as GroupedQuantize
    participant CUDA as cuBLAS/CUDA
    User->>Einsum: einsum("EBCM,EMH->EBCH", x, w, quantizer_sets)
    Einsum->>Einsum: Parse equation & validate NN layout
    Einsum->>Einsum: Stack quantizer_sets into pytree
    Einsum->>Dense: vmap(dense_with_quantizer) over batch dim E
    loop For each expert (vmapped)
        Dense->>Quant: grouped_quantize(x, quantizer_i)
        Quant->>CUDA: GroupedQuantizeFFI (batched)
        CUDA-->>Quant: quantized tensors + scales
        Dense->>Quant: grouped_quantize(w, quantizer_i)
        Quant->>CUDA: GroupedQuantizeFFI (batched)
        CUDA-->>Quant: quantized tensors + scales
        Dense->>GEMM: gemm(x_q, w_q, scales)
        GEMM->>CUDA: nvte_grouped_gemm (if batched)
        Note over CUDA: GPU-side setup kernel<br/>No D2H memcpy
        CUDA->>CUDA: cublasLtMatmul (grouped)
        CUDA-->>GEMM: output
        GEMM-->>Dense: result
    end
    Dense-->>Einsum: vmapped outputs
    Einsum-->>User: final result
```
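A rough usage sketch of the flow the diagram describes, assuming the shapes implied by the einsum equation "EBCM,EMH->EBCH". Here `dense_one_expert` is a hypothetical stand-in for the PR's dense-with-quantizer path; the per-expert quantization and grouped-GEMM primitive are elided and replaced with a plain einsum:

```python
import jax
import jax.numpy as jnp

E, B, C, M, H = 8, 4, 16, 64, 128  # experts, batch, capacity, hidden dims

x = jnp.ones((E, B, C, M), dtype=jnp.bfloat16)  # per-expert activations
w = jnp.ones((E, M, H), dtype=jnp.bfloat16)     # per-expert weights

# Hypothetical per-expert dense; the real path quantizes x and w
# (grouped_quantize) and dispatches to the grouped-GEMM primitive.
def dense_one_expert(x_e, w_e):
    return jnp.einsum("bcm,mh->bch", x_e, w_e)

# vmap over the expert dimension E, matching "EBCM,EMH->EBCH".
out = jax.vmap(dense_one_expert)(x, w)
assert out.shape == (E, B, C, H)
```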
Comment on lines 405 to 406:
```cpp
cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0,
                outputs->size_bytes() - used_output_size, stream);
```
style: potential pointer arithmetic issue with untyped data

The pointer arithmetic `outputs->untyped_data() + used_output_size` relies on byte-addressed (`char*`-style) arithmetic on an untyped pointer; an explicit `char*` cast makes the intent unambiguous. Also verify that `used_output_size` is calculated in bytes, not elements.
Suggested change

```diff
-cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0,
-                outputs->size_bytes() - used_output_size, stream);
+size_t used_output_size = (sum_group_sizes * non_group_m) * n * output_dtype_bytes;
+char* output_base = static_cast<char*>(outputs->untyped_data());
+cudaMemsetAsync(output_base + used_output_size, 0, outputs->size_bytes() - used_output_size, stream);
```
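To make the bytes-versus-elements point concrete, a small host-side sketch of the tail computation; `unused_tail_bytes` and its parameters are illustrative names lifted from the suggested change, not the actual FFI code:

```python
import numpy as np

def unused_tail_bytes(group_sizes, non_group_m, n, total_alloc_elems, dtype):
    # used_output_size must be in bytes: element count times bytes per element.
    used_elems = sum(group_sizes) * non_group_m * n
    used_bytes = used_elems * np.dtype(dtype).itemsize
    total_bytes = total_alloc_elems * np.dtype(dtype).itemsize
    # This difference is the region the cudaMemsetAsync call zeroes out.
    return total_bytes - used_bytes

# Example: 3 groups whose padded allocation leaves a 128-byte tail to zero.
tail = unused_tail_bytes([4, 2, 6], non_group_m=1, n=16,
                         total_alloc_elems=16 * 16, dtype=np.float16)
```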
Force-pushed from f58ba23 to d799a29.
Description
Depends on #2502
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: